[WIP] Add MarianMT to models exportable with ONNX #13854
Maxinho96 wants to merge 4 commits into huggingface:master
_SUPPORTED_MODEL_KIND = {
    "albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig),
    "bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
    "marian": supported_features_mapping("seq2seq-lm", onnx_config_cls=MarianOnnxConfig),
You can also add the "default" task.
        self._setup_normalizer()

-    def num_special_tokens_to_add(self, **unused):
+    def num_special_tokens_to_add(self, *args, **kwargs):
What is the reason for this?
num_special_tokens_to_add is called here with a positional argument, which raises an error if the function is defined to accept only keyword arguments (**unused).
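The failure mode described here can be reproduced in plain Python without any tokenizer. In the minimal sketch below, both classes and the `call_positionally` helper are hypothetical, and the return value `1` is only a placeholder. It shows that a method declared with `**unused` alone rejects a positional argument, while `*args, **kwargs` accepts it.

```python
class KeywordOnlyTokenizer:
    """Hypothetical class mirroring the original signature."""

    def num_special_tokens_to_add(self, **unused):
        return 1  # placeholder value


class FlexibleTokenizer:
    """Hypothetical class mirroring the signature proposed in the diff."""

    def num_special_tokens_to_add(self, *args, **kwargs):
        return 1  # placeholder value


def call_positionally(tokenizer):
    """Call the method with one positional argument, as the caller does."""
    try:
        return tokenizer.num_special_tokens_to_add(False)
    except TypeError as exc:
        return f"TypeError: {exc}"


print(call_positionally(KeywordOnlyTokenizer()))  # reports a TypeError
print(call_positionally(FlexibleTokenizer()))     # returns the placeholder
```

An alternative that keeps the signature explicit would be to name the expected positional parameter rather than accepting arbitrary `*args`; which is preferable depends on what the caller actually passes.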
        >>> last_hidden_states = outputs.last_hidden_state
        """
        # different to other models, Marian automatically creates decoder_input_ids from
@LysandreJik @patil-suraj What do you think?
Just saw this comment, will investigate and come back to you.
Thank you. Anyway, I just copied what BART does here.
This shouldn't be added here. This is done for BART because of the denoising pre-training objective, so only BART and mBART prepare the decoder_input_ids from input_ids. For Marian, users should pass either decoder_input_ids or labels, and this is already handled by the MarianMTModel class.
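To make the distinction concrete, here is a minimal pure-Python sketch of the "shift right" idea and of a resolution step that requires either decoder_input_ids or labels. Both functions are hypothetical illustrations that operate on plain lists of token ids, and the token id values in the usage lines are made up. They are not the transformers implementation.

```python
def shift_tokens_right(token_ids, pad_token_id, decoder_start_token_id):
    """Shift one sequence right by one position, prepending the start token."""
    shifted = [decoder_start_token_id] + list(token_ids[:-1])
    # Labels commonly use -100 as an ignore index; map it back to padding.
    return [pad_token_id if t == -100 else t for t in shifted]


def resolve_decoder_input_ids(
    decoder_input_ids=None, labels=None, pad_token_id=0, decoder_start_token_id=0
):
    """Hypothetical helper: use explicit decoder inputs, else derive from labels."""
    if decoder_input_ids is not None:
        return list(decoder_input_ids)
    if labels is not None:
        return shift_tokens_right(labels, pad_token_id, decoder_start_token_id)
    # Assumed behavior for a translation model: nothing to derive inputs from.
    raise ValueError("Pass either decoder_input_ids or labels")


# Made-up token ids for illustration only.
print(resolve_decoder_input_ids(labels=[5, 6, 7, 2]))
print(resolve_decoder_input_ids(decoder_input_ids=[0, 5, 6]))
```

Under this reading, a caller that supplies neither argument has nothing from which decoder inputs can be derived.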
> This shouldn't be added here. This is done for BART because of the denoising pre-training objective, so only BART and mBART prepare the decoder_input_ids from input_ids. For Marian users should pass either decoder_input_ids or labels and this is already handled by the MarianMTModel class.
Thank you @patil-suraj for the reply, but as I said in this comment, transformers.onnx.convert.export() calls MarianMTModel.forward() without passing decoder_input_ids or labels, so how is this supposed to be handled?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Resolves #13823
@patil-suraj @michaelbenayoun